
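### Maximum likelihood estimation
## Input: 
##      Z, D, Y : observed data, compared against 0/1 in the likelihood
##                (Z == 1, D == 0/1, Y == 0/1)
##      X       : design matrix for the target parameter theta
##      X.phi   : design matrix for the nuisance terms phi1..phi4 and op
##                (defaults to X)
##      target  : target parameter, "LATE" or "MLATE"
## Output: 
##      list(val, error)
##        val   : stacked vector of fitted coefficients; the theta
##                coefficients come first (presumably the estimate of alpha),
##                followed by the phi1..phi4 and op coefficients
##        error : NA on success; otherwise the caught error, in which case
##                val is all NA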

MLEst = function(Z, D, Y, X, X.phi = X, target = "LATE"){
  
  ## Maximum likelihood fit of the stacked coefficient vector.
  ##
  ## Arguments:
  ##   Z, D, Y : observed data with one entry per row of X; the likelihood
  ##             compares them against 0/1 (Z == 1, D == 0/1, Y == 0/1)
  ##   X       : design matrix for theta (tanh link for "LATE", exp link
  ##             for "MLATE")
  ##   X.phi   : design matrix for phi1..phi4 (through expit) and op
  ##             (through exp); defaults to X
  ##   target  : "LATE" or "MLATE"
  ##
  ## Value:
  ##   list(val, error). val is the fitted coefficient vector laid out as
  ##   [theta | phi1 | phi2 | phi3 | phi4 | op], i.e. ncol(X) entries followed
  ##   by five groups of ncol(X.phi) entries. error is NA on success; on
  ##   failure val is all NA and error holds the caught condition.
  ##
  ## expit() and getProb() are helpers defined outside this function.
  ## getProb() is used as returning one row per observation and 8 columns:
  ## column pairs (1,2), (3,4), (5,6), (7,8) correspond to
  ## (D,Y) = (0,0), (0,1), (1,0), (1,1); within a pair the even column is
  ## used when Z == 1 and the odd column otherwise.
  
  ## Coerce so that a plain vector (single column) or a data frame also works
  ## and ncol() is always defined for the error handler below.
  X     = as.matrix(X)
  X.phi = as.matrix(X.phi)
  
  n    = nrow(X)
  pX   = ncol(X)
  pPhi = ncol(X.phi)       # X.phi may have a different width than X
  nPar = pX + 5 * pPhi
  
  ## Index sets of the six coefficient blocks inside the stacked vector:
  ## block 1 = theta, blocks 2..5 = phi1..phi4, block 6 = op.
  blocks = c(list(seq_len(pX)),
             lapply(0:4, function(k) pX + k * pPhi + seq_len(pPhi)))
  
  ctrl = list(maxit = 10000)
  
  ## Negative log likelihood of the full stacked coefficient vector.
  ## obs.idx (row, column) pairs are set up once inside the tryCatch below.
  nLL = function(pars){
    lin   = X %*% pars[blocks[[1]]]
    theta = if (target == "LATE") tanh(lin) else exp(lin)
    phi1  = expit(X.phi %*% pars[blocks[[2]]])
    phi2  = expit(X.phi %*% pars[blocks[[3]]])
    phi3  = expit(X.phi %*% pars[blocks[[4]]])
    phi4  = expit(X.phi %*% pars[blocks[[5]]])
    op    = exp(X.phi %*% pars[blocks[[6]]])
    p     = getProb(theta, phi1, phi2, phi3, phi4, op, target)
    
    ## pick, for every observation, the probability of its observed cell
    -sum(log(p[obs.idx]))
  }
  
  ## Negative log likelihood as a function of one coefficient block only;
  ## all other coefficients are frozen at the values in `pars`.
  nLLblock = function(idx, pars){
    force(idx); force(pars)
    function(sub){
      pars[idx] = sub
      nLL(pars)
    }
  }
  
  tryCatch({
    
    ## Input checks live inside tryCatch so that failures are reported
    ## through the returned list, like every other failure of this function.
    if (!(is.character(target) && length(target) == 1 &&
          target %in% c("LATE", "MLATE")))
      stop("target must be \"LATE\" or \"MLATE\"")
    if (anyNA(Z) || anyNA(D) || anyNA(Y))
      stop("Z, D and Y must not contain missing values")
    if (NROW(Z) != n || NROW(D) != n || NROW(Y) != n || nrow(X.phi) != n)
      stop("Z, D, Y and X.phi must each have one entry/row per row of X")
    
    ## Observed (D,Y) cell per observation; anything that is not one of the
    ## three explicit combinations below falls into cell 1.
    cell = rep(1, n)
    cell[as.vector(D == 0 & Y == 1)] = 2
    cell[as.vector(D == 1 & Y == 0)] = 3
    cell[as.vector(D == 1 & Y == 1)] = 4
    ## column of getProb()'s output: 2*cell - 1 for Z != 1, 2*cell for Z == 1
    obs.idx = cbind(seq_len(n), 2 * cell - 1 + as.vector(Z == 1))
    
    ## Step 1: joint fit of all coefficients from a zero start.
    para.start1 = rep(0, nPar)
    opt  = optim(para.start1, nLL, control = ctrl)
    pars = opt$par
    
    ## Step 2: cyclic block-wise refinement until the relative decrease of
    ## the negative log likelihood over one full sweep is at most 1e-6, or
    ## 100 sweeps have been done. The reference value for the first sweep is
    ## the NLL at the starting point (unchanged from the previous behaviour).
    rel.gain  = 1
    value.old = nLL(para.start1)
    sweep     = 0
    while (rel.gain > 1e-6 && sweep < 100) {
      sweep = sweep + 1
      
      for (idx in blocks) {
        fit       = optim(pars[idx], nLLblock(idx, pars), control = ctrl)
        pars[idx] = fit$par
      }
      
      ## `fit` now holds the result for the last block (op)
      rel.gain  = (value.old - fit$value) / fit$value
      value.old = fit$value
    }
    
    list(val = pars, error = NA)
  },
  
  error = function(e) { 
    list(val = rep(NA, nPar), error = e)
  })
  
}

    
    
    
    
    
    
    
